{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4e3fc631-fcfb-4e87-a8bc-e2d376ffeb99",
   "metadata": {},
   "source": [
    "Reference: [scikit-learn user guide: Principal component analysis (PCA)](https://scikit-learn.org/1.1/modules/decomposition.html#principal-component-analysis-pca)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c687cfd-1406-42d5-9b43-98351e7b0bf6",
   "metadata": {},
   "source": [
    "## Training and Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "16c5dd19-9892-421b-9256-bf1aed24b9bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((150, 4), (150,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "X, y = load_iris(return_X_y=True)\n",
    "(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6ec0a774-4a5a-4707-99a5-5e05f476bd77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(array([5.1, 3.5, 1.4, 0.2]), 0), (array([4.9, 3. , 1.4, 0.2]), 0)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(X[0], y[0]), (X[1], y[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "157c2d54-36a4-4cf6-9bfc-fdd2e23d6d09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.68412563,  0.31939725])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "pca = PCA(n_components=2)\n",
    "X_r = pca.fit(X).transform(X)\n",
    "X_r[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "31108cb4-1a18-44b0-a246-81e1855dcf25",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023/03/29 13:14:01 WARNING mlflow.sklearn: Model was missing function: predict. Not logging python_function flavor!\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
      "  warnings.warn(\"Setuptools is replacing distutils.\")\n",
      "2023/03/29 13:14:03 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_pca' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:14:04 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_pca, version 3\n",
      "Created version '3' of model 'da_pca'.\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/liga/mlflow/logger.py:137: UserWarning: value of rikai.output.schema is None or empty and will not be populated to MLflow\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "import mlflow\n",
    "from liga.sklearn.mlflow import log_model\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "mlflow_tracking_uri = \"sqlite:///mlruns.db\"\n",
    "mlflow.set_tracking_uri(mlflow_tracking_uri)\n",
    "\n",
    "# train a model\n",
    "with mlflow.start_run() as run:\n",
    "    ####\n",
    "    # Part 1: Train the model and register it on MLflow\n",
    "    ####\n",
    "    model = PCA(n_components=2)\n",
    "    model.fit(X, y)\n",
    "    registered_model_name = f\"{getpass.getuser()}_pca\"\n",
    "    log_model(model, registered_model_name=registered_model_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21c90c88-e271-4449-91c9-8aed48035ad8",
   "metadata": {},
   "source": [
    "## Apply the model to a large-scale dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4a24d667-f6b7-4f44-b431-5072e92cd53e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-29 13:14:04,133 INFO Rikai (__init__.py:121): setting spark.sql.extensions to net.xmacs.liga.spark.RikaiSparkSessionExtensions\n",
      "2023-03-29 13:14:04,133 INFO Rikai (__init__.py:121): setting spark.driver.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:04,133 INFO Rikai (__init__.py:121): setting spark.executor.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:04,134 INFO Rikai (__init__.py:121): setting spark.jars to https://github.com/komprenilo/liga/releases/download/v0.3.0/liga-spark321-assembly_2.12-0.3.0.jar\n",
      "23/03/29 13:14:05 WARN Utils: Your hostname, tubi resolves to a loopback address: 127.0.1.1; using 192.168.31.32 instead (on interface wlp0s20f3)\n",
      "23/03/29 13:14:05 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/03/29 13:14:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "23/03/29 13:14:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
      "23/03/29 13:14:09 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n",
      "23/03/29 13:14:09 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+------+----------------+-------+\n",
      "|name            |plugin|uri             |options|\n",
      "+----------------+------+----------------+-------+\n",
      "|mlflow_sklearn_m|      |mlflow:///da_pca|       |\n",
      "+----------------+------+----------------+-------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from example import spark\n",
    "\n",
    "from liga.mlflow import CONF_MLFLOW_TRACKING_URI\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")\n",
    "spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)\n",
    "spark.sql(f\"\"\"\n",
    "CREATE OR REPLACE MODEL mlflow_sklearn_m LOCATION 'mlflow:///{registered_model_name}';\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "spark.sql(\"show models\").show(1, vertical=False, truncate=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cad70b07-1272-47d0-97bf-91511c8a6b68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- mlflow_sklearn_m: array (nullable = true)\n",
      " |    |-- element: float (containsNull = true)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mlflow_sklearn_m</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[-2.6841256618499756, 0.3193972408771515]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            mlflow_sklearn_m\n",
       "0  [-2.6841256618499756, 0.3193972408771515]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from liga.numpy.sql import literal\n",
    "\n",
    "result = spark.sql(f\"\"\"\n",
    "select\n",
    "  ML_PREDICT(mlflow_sklearn_m, {literal(X[0])})\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "result.printSchema()\n",
    "result.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743ed33b-411e-4cfb-8cfa-221bd9eb0658",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
